# coding=utf-8
'''
author: chensheng
created: 2022/09/07
desc: This script parses all of the arguments the user passes in for TCLR and
forwards them to the start function. Note: every incoming argument is a str.
'''

import argparse
import sys

from TCLR import TCLRalgorithm as model


def str_to_bool(str_arg):
    """Convert the literal strings "True"/"False" to a bool.

    Any other value is a fatal CLI error: a message is written to stderr
    and the process exits with a non-zero status.  (The original printed
    to stdout and exited with status 0, which made the failure look like
    success to the calling shell script.)
    """
    if str_arg == "True":
        return True
    if str_arg == "False":
        return False
    print("It must be bool type!", file=sys.stderr)
    sys.exit(1)


def str_to_list(str_arg):
    """Parse a bracketed, quoted list literal such as "['add', 'sub']"
    into a list of strings.

    Returns None when the argument is None or the string 'None'.
    Items must be wrapped in single or double quotes; quote characters
    and brackets never appear inside the items themselves (per the
    caller's contract noted in the original comment).
    """
    if str_arg is None or str_arg == 'None':
        return None

    res_list = []
    index = 1  # skip the leading '['
    while str_arg[index] != ']':
        if str_arg[index] in ("'", '"'):
            index += 1
            start_index = index
            # scan forward to the closing quote of this item
            while str_arg[index] not in ("'", '"'):
                index += 1
            res_list.append(str_arg[start_index:index])
        index += 1
    return res_list


def str_to_list_for_tolerance_list(str_arg):
    """Parse a nested list literal such as "[['f1', 0.1], ['f2', 0.2]]"
    into a list of [name, tolerance] pairs.

    Each inner "[...]" chunk is handed to tmpstr_to_list, which converts
    it into [str, float].
    """
    # NOTE: the original called str_arg.replace(" ", "") and discarded the
    # result (str is immutable), so it was a no-op; the parser below never
    # depended on it and the dead call is simply removed.
    res_list = []
    inner = str_arg[1:-1]  # strip the outer brackets
    start_index = 0
    for index, ch in enumerate(inner):
        if ch == '[':
            start_index = index
        elif ch == ']':
            res_list.append(tmpstr_to_list(inner[start_index:index + 1]))
    return res_list


def tmpstr_to_list(str_arg):
    """Convert one "['name', value]" chunk of the tolerance list into
    [name(str), value(float)].

    Expects the exact layout the CLI produces: the chunk starts with
    "['", the name runs to the next single quote, and the numeric part
    runs up to the closing ']'.
    """
    # The original called str_arg.replace(" ", "") and discarded the result
    # (str is immutable) — a no-op, removed here.  float() tolerates the
    # leading space that may remain in the numeric slice.
    index = 2  # skip the leading "[" and "'"
    start_index = index
    while str_arg[index] != "'":
        index += 1
    name = str_arg[start_index:index]
    start_index = index + 2  # skip the closing quote and the comma
    while str_arg[index] != "]":
        index += 1
    value = float(str_arg[start_index:index])
    return [name, value]


def path_remake(path):
    """Backslash-escape spaces and parentheses in *path* for shell use.

    Uses explicit '\\\\' escapes: the original relied on invalid escape
    sequences ('\\ ', '\\(', '\\)') which emit SyntaxWarning /
    DeprecationWarning on modern Python.  Runtime output is identical.
    """
    return path.replace(' ', '\\ ').replace('(', '\\(').replace(')', '\\)')


# Every value arrives from the shell script as a string (defaults were
# already applied there); each is converted to its real type below.
parser = argparse.ArgumentParser()
parser.add_argument("--filePath", dest="filePath", type=str)
parser.add_argument("--correlation", dest="correlation", type=str)
parser.add_argument("--tolerance_list", dest="tolerance_list", type=str)
parser.add_argument("--gpl_dummyfea", dest="gpl_dummyfea", type=str)
parser.add_argument("--minsize", dest="minsize", type=str)
parser.add_argument("--threshold", dest="threshold", type=str)
parser.add_argument("--mininc", dest="mininc", type=str)
parser.add_argument("--split_tol", dest="split_tol", type=str)
parser.add_argument("--gplearn", dest="gplearn", type=str)
parser.add_argument("--population_size", dest="population_size", type=str)
parser.add_argument("--generations", dest="generations", type=str)
parser.add_argument("--verbose", dest="verbose", type=str)
parser.add_argument("--metric", dest="metric", type=str)
parser.add_argument("--function_set", dest="function_set", type=str)

args = parser.parse_args()
filePath = path_remake(args.filePath)
# BUG FIX: the parsed --correlation value was immediately overwritten with
# the hard-coded "PearsonR(+)" (an apparent debug leftover), silently
# ignoring the user's choice.  The override is removed so the CLI argument
# is actually honoured.
correlation = str(args.correlation)
tolerance_list = str_to_list_for_tolerance_list(args.tolerance_list)
gpl_dummyfea = str_to_list(args.gpl_dummyfea)
minsize = int(args.minsize)
threshold = float(args.threshold)
mininc = float(args.mininc)
split_tol = float(args.split_tol)
gplearn = str_to_bool(args.gplearn)
population_size = int(args.population_size)
generations = int(args.generations)
verbose = int(args.verbose)
metric = args.metric
function_set = str_to_list(args.function_set)

model.start(filePath=filePath, correlation=correlation, tolerance_list=tolerance_list,
            gpl_dummyfea=gpl_dummyfea, minsize=minsize, threshold=threshold, mininc=mininc,
            split_tol=split_tol, gplearn=gplearn, population_size=population_size,
            generations=generations, verbose=verbose, metric=metric, function_set=function_set)

'''
dataSet = "testdata.csv"
correlation = 'PearsonR(+)'
minsize = 3
threshold = 0.9
mininc = 0.01
split_tol = 0.8


model.start(filePath = dataSet, correlation = correlation, minsize = minsize, threshold = threshold,mininc = mininc ,split_tol = split_tol,)
'''
